查看原文
其他

tensorflow中实现神经网络训练手写数字数据集mnist

OpenCV使者 OpenCV学堂 2019-03-28

tensorflow中实现神经网络训练手写数字数据集mnist

一:网络结构

基于tensorflow实现一个简单的三层神经网络,并使用它训练mnist数据集,神经网络三层分别为:

输入层: 

像素数据输入28x28=784 个输入节点

隐藏层:

30个神经元节点

输出层:

10个神经元节点,对应 0 ~ 9 十个数字 

图示结构如下:

网络结构的代码实现:

  1. hidden_nodes = 30

  2. x = tf.placeholder(shape=[None, 784], dtype=tf.float32)

  3. y = tf.placeholder(shape=[None, 10], dtype=tf.float32)

  4. w1 = tf.Variable(tf.truncated_normal(shape=[784, hidden_nodes]), dtype=tf.float32)

  5. b1 = tf.Variable(tf.truncated_normal(shape=[1, hidden_nodes]), dtype=tf.float32)

  6. w2 = tf.Variable(tf.truncated_normal(shape=[hidden_nodes, 10]), dtype=tf.float32)

  7. b2 = tf.Variable(tf.truncated_normal(shape=[1, 10]), dtype=tf.float32)

  8. # layer hidden

  9. nn_1 = tf.add(tf.matmul(x, w1), b1)

  10. h1 = tf.nn.sigmoid(nn_1)

  11. # layer output

  12. nn_2 = tf.add(tf.matmul(h1, w2), b2)

  13. out = tf.nn.sigmoid(nn_2)

  14. # loss function

  15. error = tf.square(tf.subtract(y, out))

  16. loss = tf.reduce_sum(error)

  17. # back prop

  18. step = tf.train.GradientDescentOptimizer(0.05).minimize(loss)

  19. init = tf.global_variables_initializer()

二:数据读取与训练

读取mnist数据集

  1. from tensorflow.examples.tutorials.mnist import inputdata

  2. mnist = inputdata.readdatasets("MNISTdata/", onehot=True)

如果不行,就下载下来,放到本地即可

执行训练的代码如下

  1. # accurate  model

  2. acc_mat = tf.equal(tf.argmax(out, 1), tf.argmax(y, 1))

  3. acc = tf.reduce_sum(tf.cast(acc_mat, tf.float32))

  4. with tf.Session() as sess:

  5.    sess.run(init)

  6.    for i in range(20000):

  7.        batch_xs, batch_ys = mnist.train.next_batch(10)

  8.        sess.run(step, feed_dict={x: batch_xs, y: batch_ys})

  9.        if i % 1000 == 0:

  10.            x_input = mnist.test.images[:1000]

  11.            y_input = mnist.test.labels[:1000]

  12.            curr_acc = sess.run(acc, feed_dict={x: x_input, y: y_input})

  13.            print("current acc : ", curr_acc)

训练结果:

测试集上对1000张手写数字图像测试正确识别921张,准确率高达92.1%。说明传统的人工神经网络表现还是不错的,这个还是在没有优化的情况下,通过修改批量数大小,修改学习率,添加隐藏层节点数与dropout正则化,可以更进一步提高识别率。



上次送书活动,感谢大家踊跃发言,留言,然图书只有三本,留言前三名

- 门德尔松

- 王健行

- 水亦心

截图为证:


请在微信公众号上,发送【本人微信号】,有效期至2018-07-14日24:00截至。过期作废!其它人可以到【京东】购买本人图书,本人一定做好答疑服务,再次感谢大家的支持与赞扬!


知不足者好学

耻下问者自满

关注【OpenCV学堂】

长按或者扫码二维码即可关注

更多相关阅读

如何学习计算机视觉OpenCV

OpenCV实现0到9数字识别OCR

TensorFlow中常量与变量的基本操作演示

TensorFlow中的feed与fetch

TensorFlow进行简单的图像处理

基于OpenCV实现手写体数字训练与识别


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存